{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ChjuaQjm_iBf"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "uWqCArLO_kez"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ikhIvrku-i-L"
      },
      "source": [
        "# Deep \u0026 Cross Network (DCN)\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/recommenders/examples/dcn\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/dcn.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/recommenders/blob/main/docs/examples/dcn.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/dcn.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q-rOX95bAye4"
      },
      "source": [
        "This tutorial demonstrates how to use Deep \u0026 Cross Network (DCN) to effectively learn feature crosses.\n",
        "\n",
        "##Background\n",
        "\n",
        "**What are feature crosses and why are they important?** Imagine that we are building a recommender system to sell a blender to customers. Then, a customer's past purchase history such as `purchased_bananas` and `purchased_cooking_books`, or geographic features, are single features. If one has purchased both bananas **and** cooking books, then this customer will more likely click on the recommended blender. The combination of `purchased_bananas` and `purchased_cooking_books` is referred to as a **feature cross**, which provides additional interaction information beyond the individual features.\n",
        "\u003cdiv\u003e\n",
        "\u003ccenter\u003e\n",
        "\u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/cross_features.gif?raw=true\" width=\"600\"/\u003e\n",
        "\u003c/center\u003e\n",
        "\u003c/div\u003e\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "**What are the challenges in learning feature crosses?** In Web-scale applications, data are mostly categorical, leading to large and sparse feature space. Identifying effective feature crosses in this setting often requires\n",
        "manual feature engineering or exhaustive search. Traditional feed-forward multilayer perceptron (MLP) models are universal function approximators; however, they cannot efficiently approximate even 2nd or 3rd-order feature crosses [[1](https://arxiv.org/pdf/2008.13535.pdf), [2](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/18fa88ad519f25dc4860567e19ab00beff3f01cb.pdf)].\n",
        "\n",
        "**What is Deep \u0026 Cross Network (DCN)?** DCN was designed to learn explicit and bounded-degree cross features more effectively. It starts with an input layer (typically an embedding layer), followed by a *cross network* containing multiple cross layers that models explicit feature interactions, and then combines\n",
        "with a *deep network* that models implicit feature interactions.\n",
        "\n",
        "\n",
        "*   Cross Network. This is the core of DCN. It explicitly applies feature crossing at each layer, and the highest\n",
        "polynomial degree increases with layer depth. The following figure shows the $(i+1)$-th cross layer.\n",
        "\u003cdiv class=\"fig figcenter fighighlight\"\u003e\n",
        "\u003ccenter\u003e\n",
        "  \u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/feature_crossing.png?raw=true\" width=\"50%\" style=\"display:block\"\u003e\n",
        "  \u003c/center\u003e\n",
        "\u003c/div\u003e\n",
        "*   Deep Network. It is a traditional feedforward multilayer perceptron (MLP).\n",
        "\n",
        "The deep network and cross network are then combined to form DCN [[1](https://arxiv.org/pdf/2008.13535.pdf)]. Commonly, we could stack a deep network on top of the cross network (stacked structure); we could also place them in parallel (parallel structure). \n",
        "\n",
        "\n",
        "\u003cdiv class=\"fig figcenter fighighlight\"\u003e\n",
        "\u003ccenter\u003e\n",
        "  \u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/parallel_deep_cross.png?raw=true\" hspace=\"40\" width=\"30%\" style=\"margin: 0px 100px 0px 0px;\"\u003e\n",
        "  \u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/stacked_deep_cross.png?raw=true\" width=\"20%\"\u003e\n",
        "  \u003c/center\u003e\n",
        "\u003c/div\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6OlIGoADAhZg"
      },
      "source": [
        "In the following, we will first show the advantage of DCN with a toy example, and then we will walk you through some common ways to utilize DCN using the MovieLen-1M dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Az-y_3qxH3gA"
      },
      "source": [
        "Let's first install and import the necessary packages for this colab.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PjfZWVEWAmxS"
      },
      "outputs": [],
      "source": [
        "!pip install -q tensorflow-recommenders\n",
        "!pip install -q --upgrade tensorflow-datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DqsyLA0UHeCl"
      },
      "outputs": [],
      "source": [
        "import pprint\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "import tensorflow_recommenders as tfrs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tHCgOeoHBGbb"
      },
      "source": [
        "## Toy Example\n",
        "To illustrate the benefits of DCN, let's work through a simple example. Suppose we have a dataset where we're trying to model the likelihood of a customer clicking on a blender Ad, with its features and label described as follows.\n",
        "\n",
        "| Features / Label        | Description           | Value Type / Range  |\n",
        "| ------------- |-------------| -----|\n",
        "| $x_1$ = country | the country this customer lives in | Int in [0, 199] |\n",
        "| $x_2$ = bananas | # bananas the customer has purchased |Int in   [0, 23] |\n",
        "| $x_3$ = cookbooks | # cooking books the customer has purchased     |Int in  [0, 5] |\n",
        "| $y$ | the likelihood of clicking on a blender Ad     |   -- |\n",
        "\n",
        "Then, we let the data follow the following underlying distribution:\n",
        "$$y = f(x_1, x_2, x_3) = 0.1x_1 + 0.4x_2+0.7x_3 + 0.1x_1x_2+3.1x_2x_3+0.1x_3^2$$\n",
        "\n",
        "where the likelihood $y$ depends linearly both on features $x_i$'s, but also on multiplicative interactions between the $x_i$'s. In our case, we would say that the likelihood of purchasing a blender ($y$) depends not just on buying bananas ($x_2$) or cookbooks ($x_3$), but also on buying bananas and cookbooks *together* ($x_2x_3$)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HO6d-0zoHrz8"
      },
      "source": [
        "We can generate the data for this as follows:\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-fi2ya4P_hab"
      },
      "source": [
        "### Synthetic data generation\n",
        "\n",
        "We first define $f(x_1, x_2, x_3)$ as described above. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9rT3f6C3GX0u"
      },
      "outputs": [],
      "source": [
        "def get_mixer_data(data_size=100_000, random_seed=42):\n",
        "  # We need to fix the random seed\n",
        "  # to make colab runs repeatable.\n",
        "  rng = np.random.RandomState(random_seed)\n",
        "  country = rng.randint(200, size=[data_size, 1]) / 200.\n",
        "  bananas = rng.randint(24, size=[data_size, 1]) / 24.\n",
        "  cookbooks = rng.randint(6, size=[data_size, 1]) / 6.\n",
        "\n",
        "  x = np.concatenate([country, bananas, cookbooks], axis=1)\n",
        "\n",
        "  # # Create 1st-order terms.\n",
        "  y = 0.1 * country + 0.4 * bananas + 0.7 * cookbooks\n",
        "\n",
        "  # Create 2nd-order cross terms.\n",
        "  y += 0.1 * country * bananas + 3.1 * bananas * cookbooks + (\n",
        "        0.1 * cookbooks * cookbooks)\n",
        "\n",
        "  return x, y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JXtUSs9E3uRG"
      },
      "source": [
        "Let's generate the data that follows the distribution, and split the data into 90% for training and 10% for testing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vrQWVYajgmNV"
      },
      "outputs": [],
      "source": [
        "x, y = get_mixer_data()\n",
        "num_train = 90000\n",
        "train_x = x[:num_train]\n",
        "train_y = y[:num_train]\n",
        "eval_x = x[num_train:]\n",
        "eval_y = y[num_train:]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MszQC-KJLhVK"
      },
      "source": [
        "### Model construction\n",
        "\n",
        "We're going to try out both cross network and deep network to illustrate the advantage a cross network can bring to recommenders. As the data we just created only contains 2nd-order feature interactions, it would be sufficient to illustrate with a single-layered cross network. If we wanted to model higher-order feature interactions, we could stack multiple cross layers and use a multi-layered cross network. The two models we will be building are:\n",
        "1.   Cross Network with only one cross layer;\n",
        "2.   Deep Network with wider and deeper ReLU layers. \n",
        "\n",
        "We first build a unified model class whose loss is the mean squared error. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bwgAH2FTR4Fe"
      },
      "outputs": [],
      "source": [
        "class Model(tfrs.Model):\n",
        "\n",
        "  def __init__(self, model):\n",
        "    super().__init__()\n",
        "    self._model = model\n",
        "    self._logit_layer = tf.keras.layers.Dense(1)\n",
        "\n",
        "    self.task = tfrs.tasks.Ranking(\n",
        "      loss=tf.keras.losses.MeanSquaredError(),\n",
        "      metrics=[\n",
        "        tf.keras.metrics.RootMeanSquaredError(\"RMSE\")\n",
        "      ]\n",
        "    )\n",
        "\n",
        "  def call(self, x):\n",
        "    x = self._model(x)\n",
        "    return self._logit_layer(x)\n",
        "\n",
        "  def compute_loss(self, features, training=False):\n",
        "    x, labels = features\n",
        "    scores = self(x)\n",
        "\n",
        "    return self.task(\n",
        "        labels=labels,\n",
        "        predictions=scores,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QAKBq2QxOdM0"
      },
      "source": [
        "Then, we specify the cross network (with 1 cross layer of size 3) and the ReLU-based DNN (with layer sizes [512, 256, 128]):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EwBwSHz_N3pW"
      },
      "outputs": [],
      "source": [
        "crossnet = Model(tfrs.layers.dcn.Cross())\n",
        "deepnet = Model(\n",
        "    tf.keras.Sequential([\n",
        "      tf.keras.layers.Dense(512, activation=\"relu\"),\n",
        "      tf.keras.layers.Dense(256, activation=\"relu\"),\n",
        "      tf.keras.layers.Dense(128, activation=\"relu\")\n",
        "    ])\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2EtaI5lh4X0B"
      },
      "source": [
        "### Model training\n",
        "Now that we have the data and models ready, we are going to train the models. We first shuffle and batch the data to prepare for model training.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X6gD-NTF4eoj"
      },
      "outputs": [],
      "source": [
        "train_data = tf.data.Dataset.from_tensor_slices((train_x, train_y)).batch(1000)\n",
        "eval_data = tf.data.Dataset.from_tensor_slices((eval_x, eval_y)).batch(1000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JYm5bmmgPVZu"
      },
      "source": [
        "Then, we define the number of epochs as well as the learning rate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nFhrC7fV6szW"
      },
      "outputs": [],
      "source": [
        "epochs = 100\n",
        "learning_rate = 0.4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zbRiVPJtPz-1"
      },
      "source": [
        "Alright, everything is ready now and let's compile and train the models. You could set `verbose=True` if you want to see how the model progresses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F8ZXXbmKuB8p"
      },
      "outputs": [],
      "source": [
        "crossnet.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate))\n",
        "crossnet.fit(train_data, epochs=epochs, verbose=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tzg3KLKW2sdA"
      },
      "outputs": [],
      "source": [
        "deepnet.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate))\n",
        "deepnet.fit(train_data, epochs=epochs, verbose=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xxWfaY6H7Bmp"
      },
      "source": [
        "### Model evaluation\n",
        "We verify the model performance on the evaluation dataset and report the Root Mean Squared Error (RMSE, the lower the better)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l4PM-goX6FoD"
      },
      "outputs": [],
      "source": [
        "crossnet_result = crossnet.evaluate(eval_data, return_dict=True, verbose=False)\n",
        "print(f\"CrossNet(1 layer) RMSE is {crossnet_result['RMSE']:.4f} \"\n",
        "      f\"using {crossnet.count_params()} parameters.\")\n",
        "\n",
        "deepnet_result = deepnet.evaluate(eval_data, return_dict=True, verbose=False)\n",
        "print(f\"DeepNet(large) RMSE is {deepnet_result['RMSE']:.4f} \"\n",
        "      f\"using {deepnet.count_params()} parameters.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_Ig-Gnm7-JD"
      },
      "source": [
        "We see that the cross network achieved **magnitudes lower RMSE** than a ReLU-based DNN, with **magnitudes fewer parameters**. This has suggested the efficieny of a cross network in learning feaure crosses.  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XsCsY1-Us-_e"
      },
      "source": [
        "### Model understanding\n",
        "We already know what feature crosses are important in our data, it would be fun to check whether our model has indeed learned the important feature cross. This can be done by visualizing the learned weight matrix in DCN. The weight $W_{ij}$ represents the learned importance of interaction between feature $x_i$ and $x_j$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N8dga2Qck5IV"
      },
      "outputs": [],
      "source": [
        "mat = crossnet._model._dense.kernel\n",
        "features = [\"country\", \"purchased_bananas\", \"purchased_cookbooks\"]\n",
        "\n",
        "plt.figure(figsize=(9,9))\n",
        "im = plt.matshow(np.abs(mat.numpy()), cmap=plt.cm.Blues)\n",
        "ax = plt.gca()\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
        "plt.colorbar(im, cax=cax)\n",
        "cax.tick_params(labelsize=10) \n",
        "_ = ax.set_xticklabels([''] + features, rotation=45, fontsize=10)\n",
        "_ = ax.set_yticklabels([''] + features, fontsize=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bQHVZTu03qvi"
      },
      "source": [
        "Darker colours represent stronger learned interactions - in this case, it's clear that the model learned that purchasing babanas and cookbooks together is important.\n",
        "\n",
        "If you are interested in trying out more complicated synthetic data, feel free to check out [this paper](https://arxiv.org/pdf/2008.13535.pdf)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0wU4FcpfHCZM"
      },
      "source": [
        "## Movielens 1M example\n",
        "We now examine the effectiveness of DCN on a real-world dataset: Movielens 1M [[3](https://grouplens.org/datasets/movielens)]. Movielens 1M is a popular dataset for recommendation research. It predicts users' movie ratings given user-related features and movie-related features. We use this dataset to demonstrate some common ways to utilize DCN."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Rvlem07wfwH"
      },
      "source": [
        "### Data processing\n",
        "\n",
        "The data processing procedure follows a similar procedure as the [basic ranking tutorial](https://www.tensorflow.org/recommenders/examples/basic_ranking)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7Y_n3EPosR4A"
      },
      "outputs": [],
      "source": [
        "ratings = tfds.load(\"movie_lens/100k-ratings\", split=\"train\")\n",
        "ratings = ratings.map(lambda x: {\n",
        "    \"movie_id\": x[\"movie_id\"],\n",
        "    \"user_id\": x[\"user_id\"],\n",
        "    \"user_rating\": x[\"user_rating\"],\n",
        "    \"user_gender\": int(x[\"user_gender\"]),\n",
        "    \"user_zip_code\": x[\"user_zip_code\"],\n",
        "    \"user_occupation_text\": x[\"user_occupation_text\"],\n",
        "    \"bucketized_user_age\": int(x[\"bucketized_user_age\"]),\n",
        "})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2Yb3KxrgSHiF"
      },
      "source": [
        "Next, we randomly split the data into 80% for training and 20% for testing.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a5-l91jR_zEo"
      },
      "outputs": [],
      "source": [
        "tf.random.set_seed(42)\n",
        "shuffled = ratings.shuffle(100_000, seed=42, reshuffle_each_iteration=False)\n",
        "\n",
        "train = shuffled.take(80_000)\n",
        "test = shuffled.skip(80_000).take(20_000)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MRHGa9mESMVz"
      },
      "source": [
        "Then, we create vocabulary for each feature."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l9qhEcHq_VfI"
      },
      "outputs": [],
      "source": [
        "feature_names = [\"movie_id\", \"user_id\", \"user_gender\", \"user_zip_code\",\n",
        "                 \"user_occupation_text\", \"bucketized_user_age\"]\n",
        "\n",
        "vocabularies = {}\n",
        "\n",
        "for feature_name in feature_names:\n",
        "  vocab = ratings.batch(1_000_000).map(lambda x: x[feature_name])\n",
        "  vocabularies[feature_name] = np.unique(np.concatenate(list(vocab)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eti8kNkPSORk"
      },
      "source": [
        "### Model construction\n",
        "\n",
        "The model architecture we will be building starts with an embedding layer, which is fed into a cross network followed by a deep network. The embedding dimension is set to 32 for all the features. You could also use different embedding sizes for different features."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6lrDcBjiwnHU"
      },
      "outputs": [],
      "source": [
        "class DCN(tfrs.Model):\n",
        "\n",
        "  def __init__(self, use_cross_layer, deep_layer_sizes, projection_dim=None):\n",
        "    super().__init__()\n",
        "\n",
        "    self.embedding_dimension = 32\n",
        "\n",
        "    str_features = [\"movie_id\", \"user_id\", \"user_zip_code\",\n",
        "                    \"user_occupation_text\"]\n",
        "    int_features = [\"user_gender\", \"bucketized_user_age\"]\n",
        "\n",
        "    self._all_features = str_features + int_features\n",
        "    self._embeddings = {}\n",
        "\n",
        "    # Compute embeddings for string features.\n",
        "    for feature_name in str_features:\n",
        "      vocabulary = vocabularies[feature_name]\n",
        "      self._embeddings[feature_name] = tf.keras.Sequential(\n",
        "          [tf.keras.layers.StringLookup(\n",
        "              vocabulary=vocabulary, mask_token=None),\n",
        "           tf.keras.layers.Embedding(len(vocabulary) + 1,\n",
        "                                     self.embedding_dimension)\n",
        "    ])\n",
        "      \n",
        "    # Compute embeddings for int features.\n",
        "    for feature_name in int_features:\n",
        "      vocabulary = vocabularies[feature_name]\n",
        "      self._embeddings[feature_name] = tf.keras.Sequential(\n",
        "          [tf.keras.layers.IntegerLookup(\n",
        "              vocabulary=vocabulary, mask_value=None),\n",
        "           tf.keras.layers.Embedding(len(vocabulary) + 1,\n",
        "                                     self.embedding_dimension)\n",
        "    ])\n",
        "\n",
        "    if use_cross_layer:\n",
        "      self._cross_layer = tfrs.layers.dcn.Cross(\n",
        "          projection_dim=projection_dim,\n",
        "          kernel_initializer=\"glorot_uniform\")\n",
        "    else:\n",
        "      self._cross_layer = None\n",
        "\n",
        "    self._deep_layers = [tf.keras.layers.Dense(layer_size, activation=\"relu\")\n",
        "      for layer_size in deep_layer_sizes]\n",
        "\n",
        "    self._logit_layer = tf.keras.layers.Dense(1)\n",
        "\n",
        "    self.task = tfrs.tasks.Ranking(\n",
        "      loss=tf.keras.losses.MeanSquaredError(),\n",
        "      metrics=[tf.keras.metrics.RootMeanSquaredError(\"RMSE\")]\n",
        "    )\n",
        "\n",
        "  def call(self, features):\n",
        "    # Concatenate embeddings\n",
        "    embeddings = []\n",
        "    for feature_name in self._all_features:\n",
        "      embedding_fn = self._embeddings[feature_name]\n",
        "      embeddings.append(embedding_fn(features[feature_name]))\n",
        "\n",
        "    x = tf.concat(embeddings, axis=1)\n",
        "\n",
        "    # Build Cross Network\n",
        "    if self._cross_layer is not None:\n",
        "      x = self._cross_layer(x)\n",
        "    \n",
        "    # Build Deep Network\n",
        "    for deep_layer in self._deep_layers:\n",
        "      x = deep_layer(x)\n",
        "\n",
        "    return self._logit_layer(x)\n",
        "\n",
        "  def compute_loss(self, features, training=False):\n",
        "    labels = features.pop(\"user_rating\")\n",
        "    scores = self(features)\n",
        "    return self.task(\n",
        "        labels=labels,\n",
        "        predictions=scores,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jDiRfzwVW9LH"
      },
      "source": [
        "### Model training\n",
        "We shuffle, batch and cache the training and test data. \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qeFjmfUbgzcS"
      },
      "outputs": [],
      "source": [
        "cached_train = train.shuffle(100_000).batch(8192).cache()\n",
        "cached_test = test.batch(4096).cache()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5adSI3yOt2VQ"
      },
      "source": [
        "Let's define a function that runs a model multiple times and returns the model's RMSE mean and standard deviation out of multiple runs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gTDk3GloquHO"
      },
      "outputs": [],
      "source": [
        "def run_models(use_cross_layer, deep_layer_sizes, projection_dim=None, num_runs=5):\n",
        "  models = []\n",
        "  rmses = []\n",
        "\n",
        "  for i in range(num_runs):\n",
        "    model = DCN(use_cross_layer=use_cross_layer,\n",
        "                deep_layer_sizes=deep_layer_sizes,\n",
        "                projection_dim=projection_dim)\n",
        "    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))\n",
        "    models.append(model)\n",
        "\n",
        "    model.fit(cached_train, epochs=epochs, verbose=False)\n",
        "    metrics = model.evaluate(cached_test, return_dict=True)\n",
        "    rmses.append(metrics[\"RMSE\"])\n",
        "\n",
        "  mean, stdv = np.average(rmses), np.std(rmses)\n",
        "\n",
        "  return {\"model\": models, \"mean\": mean, \"stdv\": stdv}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZRHjQ8g2h2-k"
      },
      "source": [
        "We set some hyper-parameters for the models. Note that these hyper-parameters are set globally for all the models for demonstration purpose. If you want to obtain the best performance for each model, or conduct a fair comparison among models, then we'd suggest you to fine-tune the hyper-parameters. Remember that the model architecture and optimization schemes are intertwined."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zy3kWb5Dh0E7"
      },
      "outputs": [],
      "source": [
        "epochs = 8\n",
        "learning_rate = 0.01"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nz3ftiQLXdC0"
      },
      "source": [
        "**DCN (stacked).** We first train a DCN model with a stacked structure, that is, the inputs are fed to a cross network followed by a deep network.\n",
        "\u003cdiv\u003e\n",
        "\u003ccenter\u003e\n",
        "\u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/stacked_structure.png?raw=true\" width=\"140\"/\u003e\n",
        "\u003c/center\u003e\n",
        "\u003c/div\u003e\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hiuYPJWhgw3J"
      },
      "outputs": [],
      "source": [
        "dcn_result = run_models(use_cross_layer=True,\n",
        "                        deep_layer_sizes=[192, 192])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZwTn_UpDX_iO"
      },
      "source": [
        "**Low-rank DCN.** To reduce the training and serving cost, we leverage low-rank techniques to approximate the DCN weight matrices. The rank is passed in through argument `projection_dim`; a smaller `projection_dim` results in a lower cost. Note that `projection_dim` needs to be smaller than (input size)/2 to reduce the cost. In practice, we've observed using low-rank DCN with rank (input size)/4 consistently preserved the accuracy of a full-rank DCN.\n",
        "\n",
        "\u003cdiv\u003e\n",
        "\u003ccenter\u003e\n",
        "\u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/low_rank_dcn.png?raw=true\" width=\"400\"/\u003e\n",
        "\u003c/center\u003e\n",
        "\u003c/div\u003e\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NYxbHI7ZNJX7"
      },
      "outputs": [],
      "source": [
        "dcn_lr_result = run_models(use_cross_layer=True,\n",
        "                           projection_dim=20,\n",
        "                           deep_layer_sizes=[192, 192])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5O5AoNOdaQ80"
      },
      "source": [
        "**DNN.** We train a same-sized DNN model as a reference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iBPpwD4cGtXF"
      },
      "outputs": [],
      "source": [
        "dnn_result = run_models(use_cross_layer=False,\n",
        "                        deep_layer_sizes=[192, 192, 192])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cBY0ljpl3_k5"
      },
      "source": [
        "We evaluate the model on test data and report the mean and standard deviation out of 5 runs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a1yj3pp0glEL"
      },
      "outputs": [],
      "source": [
        "print(\"DCN            RMSE mean: {:.4f}, stdv: {:.4f}\".format(\n",
        "    dcn_result[\"mean\"], dcn_result[\"stdv\"]))\n",
        "print(\"DCN (low-rank) RMSE mean: {:.4f}, stdv: {:.4f}\".format(\n",
        "    dcn_lr_result[\"mean\"], dcn_lr_result[\"stdv\"]))\n",
        "print(\"DNN            RMSE mean: {:.4f}, stdv: {:.4f}\".format(\n",
        "    dnn_result[\"mean\"], dnn_result[\"stdv\"]))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K076UbT1nnq3"
      },
      "source": [
        "We see that DCN achieved better performance than a same-sized DNN with ReLU layers. Moreover, the low-rank DCN was able to reduce parameters while maintaining the accuracy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eSF0gNLGX1Za"
      },
      "source": [
        "**More on DCN.** Besides what've been demonstrated above, there are more creative yet practically useful ways to utilize DCN [[1](https://arxiv.org/pdf/2008.13535.pdf)]. \n",
        "\n",
        "*   *DCN with a parallel structure*.  The inputs are fed in parallel to a cross network and a deep network.\n",
        "\n",
        "*   *Concatenating cross layers.* The inputs are fed in parallel to multiple cross layers to capture complementary feature crosses.\n",
        "\n",
        "\u003cdiv class=\"fig figcenter fighighlight\"\u003e\n",
        "\u003ccenter\u003e\n",
        "  \u003cimg src=\"https://github.com/tensorflow/recommenders/blob/main/assets/alternate_dcn_structures.png?raw=true\" hspace=40 width=\"600\" style=\"display:block;\"\u003e\n",
        "  \u003cdiv class=\"figcaption\"\u003e\n",
        "  \u003cb\u003eLeft\u003c/b\u003e: DCN with a parallel structure; \u003cb\u003eRight\u003c/b\u003e: Concatenating cross layers. \n",
        "  \u003c/div\u003e\n",
        "  \u003c/center\u003e\n",
        "\u003c/div\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GEi9PtCEdyma"
      },
      "source": [
        "### Model understanding\n",
        "\n",
        "The weight matrix $W$ in DCN reveals what feature crosses the model has learned to be important. Recall that in the previous toy example, the importance of interactions between the $i$-th and $j$-th features is captured by the ($i, j$)-th element of $W$.\n",
        "\n",
        "What's a bit different here is that the feature embeddings are of size 32 instead of size 1. Hence, the importance will be characterized by the $(i, j)$-th block\n",
        "$W_{i,j}$ which is of dimension 32 by 32.\n",
        "In the following, we visualize the Frobenius norm [[4](https://en.wikipedia.org/wiki/Matrix_norm)] $||W_{i,j}||_F$ of each block, and a larger norm would suggest higher importance (assuming the features' embeddings are of similar scales).\n",
        "\n",
        "Besides block norm, we could also visualize the entire matrix, or the mean/median/max value of each block."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "47ibaEBJxOoe"
      },
      "outputs": [],
      "source": [
        "model = dcn_result[\"model\"][0]\n",
        "mat = model._cross_layer._dense.kernel\n",
        "features = model._all_features\n",
        "\n",
        "block_norm = np.ones([len(features), len(features)])\n",
        "\n",
        "dim = model.embedding_dimension\n",
        "\n",
        "# Compute the norms of the blocks.\n",
        "for i in range(len(features)):\n",
        "  for j in range(len(features)):\n",
        "    block = mat[i * dim:(i + 1) * dim,\n",
        "                j * dim:(j + 1) * dim]\n",
        "    block_norm[i,j] = np.linalg.norm(block, ord=\"fro\")\n",
        "\n",
        "plt.figure(figsize=(9,9))\n",
        "im = plt.matshow(block_norm, cmap=plt.cm.Blues)\n",
        "ax = plt.gca()\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
        "plt.colorbar(im, cax=cax)\n",
        "cax.tick_params(labelsize=10) \n",
        "_ = ax.set_xticklabels([\"\"] + features, rotation=45, ha=\"left\", fontsize=10)\n",
        "_ = ax.set_yticklabels([\"\"] + features, fontsize=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pQH-moYd6ZKC"
      },
      "source": [
        "That's all for this colab! We hope that you have enjoyed learning some basics of DCN and common ways to utilize it. If you are interested in learning more, you could check out two relevant papers: [DCN-v1-paper](https://arxiv.org/pdf/1708.05123.pdf), [DCN-v2-paper](https://arxiv.org/pdf/2008.13535.pdf)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FfAGbq2es2Yn"
      },
      "source": [
        "---\n",
        "\n",
        "\n",
        "##References\n",
        "[DCN V2: Improved Deep \u0026 Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/pdf/2008.13535.pdf). \\\n",
        "*Ruoxi Wang, Rakesh Shivanna, Derek Zhiyuan Cheng, Sagar Jain, Dong Lin, Lichan Hong, Ed Chi. (2020)*\n",
        "\n",
        "\n",
        "[Deep \u0026 Cross Network for Ad Click Predictions](https://arxiv.org/pdf/1708.05123.pdf). \\\n",
        "*Ruoxi Wang, Bin Fu, Gang Fu, Mingliang Wang. (AdKDD 2017)*"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "dcn.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
